/*    Copyright 2010 Tobias Marschall
 *
 *    This file is part of MoSDi.
 *
 *    MoSDi is free software: you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation, either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    MoSDi is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with MoSDi.  If not, see <http://www.gnu.org/licenses/>.
 */

package mosdi.fa;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;

import mosdi.util.BitArray;
import mosdi.util.Log;

/** Represents a counting deterministic finite automaton (CDFA).
 *  <p>
 *  States are numbered 0..getStateCount()-1; state 0 is the start state. Each state
 *  carries an integer output, and running the automaton over a text sums (or reports)
 *  the outputs of visited states. A negative transition target denotes an invalid state.
 *  Characters that are not part of the alphabet are mapped onto the wildcard character
 *  '#' when the alphabet contains it.
 */
public class CDFA implements MinimizableAutomaton {
	private final Alphabet alphabet;
	private final List<int[]> transitionTable;
	private final List<Integer> outputTable;
	
	/** Constructor. The given lists are stored as-is (not copied).
	 *  @param transitionTable List that contains one entry for each state, where an entry is
	 *    a array of integers of size alphabet.size(), i.e. transitionTable.get(state)[char] 
	 *    gives the target state for a transition from "state" under "char". 
	 *  @param outputTable List that contains the output of each state (indexed by state).
	 */
	public CDFA(Alphabet alphabet, List<int[]> transitionTable, List<Integer> outputTable) {
		this.alphabet = alphabet;
		this.transitionTable = transitionTable;
		this.outputTable = outputTable;
	}

	/** Maps a character to its alphabet index. Characters not contained in the alphabet
	 *  are mapped onto the index of '#'.
	 *  @param ch the character to translate.
	 *  @param unknownIndex the alphabet index of '#', or a negative value if '#' is not
	 *     part of the alphabet.
	 *  @throws IllegalArgumentException if ch is unknown and there is no '#' fallback.
	 */
	private int charIndex(char ch, int unknownIndex) {
		int c = alphabet.getIndex(ch);
		if (c>=0) return c;
		// map all other (those not in the pattern) characters onto #
		if (unknownIndex>=0) return unknownIndex;
		// report the offending character itself, not its (negative) lookup result
		throw new IllegalArgumentException(String.format("Unknown character \"%s\" encountered", ch));
	}

	/** Count matches.
	 *  @return Returns the sum of outputs of visited states (including the start state).
	 *     Returns -1 if string led to an invalid state.
	 *  @throws IllegalArgumentException if s contains a character outside the alphabet
	 *     and the alphabet has no '#' wildcard. */
	public int countMatches(String s) {
		int unknownIndex = alphabet.getIndex('#');
		int node = 0;
		int count = outputTable.get(0);
		for (int i=0; i<s.length(); ++i) {
			int c = charIndex(s.charAt(i), unknownIndex);
			node = transitionTable.get(node)[c];
			if (node<0) return -1;
			count+=outputTable.get(node);
		}
		return count;
	}

	/** Count matches; equivalent to countMatches(s, false).
	 *  @return Returns the sum of outputs of visited states. */
	public int countMatches(int[] s) {
		return countMatches(s, false);
	}

	/** Count matches.
	 *  @param s text given as alphabet indices; negative entries denote unknown characters.
	 *  @param restartAtUnknownChars If true, the automaton moves to start state for
	 *                               unknown characters (c<0).
	 *  @return Returns the sum of outputs of visited states.
	 *     Returns -1 if string led to an invalid state. */
	public int countMatches(int[] s, boolean restartAtUnknownChars) {
		int unknownIndex = alphabet.getIndex('#');
		int node = 0;
		int count = outputTable.get(0);
		for (int i=0; i<s.length; ++i) {
			int c = s[i];
			// map all other (those not in the pattern) characters onto #
			if (c<0) {
				if (restartAtUnknownChars) { 
					// note: the start state's output is not added again on restart
					node = 0;
					continue;
				}
				if (unknownIndex>=0) c=unknownIndex;
				else throw new IllegalArgumentException(String.format("Unknown character \"%s\" encountered", c));
			}
			node = transitionTable.get(node)[c];
			if (node<0) return -1;
			count+=outputTable.get(node);
		}
		return count;
	}

	/** Returns, for every position of s, the output of the state reached after reading
	 *  that position (the start state's output is not included).
	 * 
	 * @param s a String composed of characters the CDFA can understand 
	 * @return a String of length s.length() whose i-th entry is the output of the state
	 *     reached after reading s.charAt(i); for 0/1-output automata a value of 1 denotes
	 *     a match ending at the corresponding position. Note that outputs greater than 9
	 *     occupy more than one character.
	 * @throws RuntimeException if a transition leads to an invalid state.
	 */
	public String getOutputString(String s) {
		StringBuilder sb = new StringBuilder(s.length());
		int unknownIndex = alphabet.getIndex('#');
		int node = 0;
		for (int i=0; i<s.length(); ++i) {
			int c = charIndex(s.charAt(i), unknownIndex);
			node = transitionTable.get(node)[c];
			if (node<0) throw new RuntimeException(String.format("Wrong transition to state %d", node));
			sb.append(outputTable.get(node));
		}
		return sb.toString();
	}

	/** A match: the text position where it ends and the state reached there. */
	public static class MatchPosition {
		private final int position;
		private final int state;
		MatchPosition(int position, int state) {
			this.position=position;
			this.state=state;
		}
		public int getPosition() { return position; }
		public int getState() { return state; }
	}
	
	/** Compares with another object; only other CDFAs can be equal.
	 *  @see #equals(CDFA) */
	@Override
	public boolean equals(Object anotherCDFA) {
		if (anotherCDFA instanceof CDFA) {
			return equals((CDFA)anotherCDFA);
		}
		return false; 
	}

	/** Hash code consistent with {@link #equals(Object)}: equal automata must agree on the
	 *  sign of the start state's output, so that sign is used as the hash. */
	@Override
	public int hashCode() {
		return Integer.signum(outputTable.get(0));
	}
	
	/**
	 * A language-equivalence-test for CDFAs.<br>
	 * Two states are considered equivalent when the signs of their outputs agree; the test
	 * explores all jointly reachable state pairs breadth-first.<br>
	 * It takes at most<br>
	 * O(this.getStateCount() * anotherDFA.getStateCount() * alphabet.size())
	 * <p>
	 * NOTE(review): transition targets are assumed to be valid (non-negative) states;
	 * a negative target would cause an index exception here.
	 * @author Wolfgang
	 * @param anotherDFA a CDFA understanding the same alphabet as this CDFA
	 * @return true iff this CDFA is language-equivalent with anotherDFA; false if anotherDFA is null
	 */
	public boolean equals(CDFA anotherDFA) {
		if (anotherDFA == null) return false;
		if (Integer.signum(outputTable.get(0)) != Integer.signum(anotherDFA.outputTable.get(0))) {
			// initial states don't have the same acceptance behavior
			return false;
		}
		boolean[][] seen = new boolean[outputTable.size()][anotherDFA.outputTable.size()];
		LinkedList<int[]> queue = new LinkedList<int[]>();
		// initial states are always 0; only pairs with equivalent acceptance behavior are enqueued
		queue.add(new int[] {0, 0});

		while (!queue.isEmpty()) {
			int[] pair = queue.removeFirst();
			if (seen[pair[0]][pair[1]]) continue;
			seen[pair[0]][pair[1]] = true;
			for (char c : alphabet) {
				int charIdx = alphabet.getIndex(c);
				int target0 = this.transitionTable.get(pair[0])[charIdx];
				int target1 = anotherDFA.transitionTable.get(pair[1])[charIdx];
				if (seen[target0][target1]) continue; // not doing loops
				if (Integer.signum(outputTable.get(target0)) != Integer.signum(anotherDFA.outputTable.get(target1))) {
					return false;
				}
				queue.add(new int[] {target0, target1});
			}
		}
		return true;
	}
	
	/** Finds matches and returns a list of positions where matches end.
	 *  A match is reported whenever a state with output &gt; 0 is entered.
	 *  @return Returns null if string led to an invalid state.
	 */
	public List<MatchPosition> findMatchPositions(String s) {
		Log.startTimer();
		List<MatchPosition> result = new ArrayList<MatchPosition>();
		int unknownIndex = alphabet.getIndex('#');
		
		int node = 0;
		for (int i=0; i<s.length(); ++i) {
			int c = charIndex(s.charAt(i), unknownIndex);
			node = transitionTable.get(node)[c];
			// NOTE(review): the timer started above is not stopped on this early return
			if (node<0) return null;
			if (outputTable.get(node)>0) {
				result.add(new MatchPosition(i, node));
			}
		}
		Log.stopTimer("dfa matching");
		return result;
	}
	
	/** Finds matches and returns a list of positions where matches end.
	 *  @param s text given as alphabet indices (must be non-negative).
	 *  @return Returns null if string led to an invalid state.
	 */
	public List<MatchPosition> findMatchPositions(int[] s) {
		Log.startTimer();
		List<MatchPosition> result = new ArrayList<MatchPosition>();
		int node = 0;
		for (int i=0; i<s.length; ++i) {
			node = transitionTable.get(node)[s[i]];
			// NOTE(review): the timer started above is not stopped on this early return
			if (node<0) return null;
			if (outputTable.get(node)>0) {
				result.add(new MatchPosition(i, node));
			}
		}
		Log.stopTimer("dfa matching");
		return result;
	}

	/** Finds the next match (starting from given position) and returns the 
	 *  position where the found match ends. The automaton is (re)started in state 0
	 *  at startpos.
	 *  @return Returns null if string led to an invalid state or no match was found. 
	 */
	public MatchPosition findNextMatchPosition(String s, int startpos) {
		int unknownIndex = alphabet.getIndex('#');
		int node = 0;
		for (int i=startpos; i<s.length(); ++i) {
			int c = charIndex(s.charAt(i), unknownIndex);
			node = transitionTable.get(node)[c];
			if (node<0) return null;
			if (outputTable.get(node)>0) return new MatchPosition(i, node);
		}
		return null;
	}
	
	/** Minimize the CDFA (delegates to {@link #minimizeHopcroft()}). */
	public CDFA minimize() { return minimizeHopcroft(); }
	
	/** Minimize the CDFA using Hopcroft's algorithm.
	 *  @return a new CDFA with one state per block of the computed partition. */
	public CDFA minimizeHopcroft() {
		Log.startTimer();
		Partition partition = AutomataUtils.minimize(this);
		// finally construct new automaton:
		// map from block index to new state number
		partition.renumberBlocks();
		// create new output and transition tables, one entry per block
		Integer[] newOutputTable = new Integer[partition.blockCount()];
		List<int[]> newTransitionTable = new ArrayList<int[]>(partition.blockCount());
		for (int i=0; i<outputTable.size(); ++i) {
			int newState = partition.getBlockIndex(i);
			if (newOutputTable[newState]!=null) continue; // block already handled
			// any member of the block can serve as its representative
			int oldState = partition.block(newState).iterator().next();
			newOutputTable[newState]=outputTable.get(oldState);
			int[] t = new int[alphabet.size()];
			for (int c=0; c<alphabet.size(); ++c) {
				t[c] = partition.getBlockIndex(transitionTable.get(oldState)[c]); 
			}
			// blocks are expected to be encountered in increasing order after renumbering
			if (newState != newTransitionTable.size()) throw new IllegalStateException();
			newTransitionTable.add(t);
		}
		
		Log.printf(Log.Level.DEBUG, "minimization (hopcroft): %d states%n", partition.blockCount());
		Log.stopTimer("minimize automaton (hopcroft)");
		
		return new CDFA(alphabet, newTransitionTable, Arrays.asList(newOutputTable));
	}

	/** Perform DFA minimization according to "Kozen, Automata and Computability"
	 *  (table-filling algorithm). Assumes that this is a DFA.
	 *  @return a new, minimized CDFA. */
	public CDFA minimizeKozen() {
		Log.startTimer();
		int nodeCount = outputTable.size();
	
		// notEquivalent[i].get(j) (j<i) is true if nodes i and j are not equivalent.
		BitArray[] notEquivalent = new BitArray[nodeCount];
		// create (triangle) matrix
		for (int i=0; i<nodeCount; ++i) {
			notEquivalent[i]=new BitArray(i);
		}
		// initialize: nodes with different outputs are distinguishable
		// TODO: more efficient using bucket-something
		for (int i=0; i<nodeCount; ++i) {
			for (int j=0; j<i; ++j) {
				// compare values, not Integer references
				if (outputTable.get(i).intValue()!=outputTable.get(j).intValue()) {
					notEquivalent[i].set(j, true);
				}
			}
		}
		// main loop: build equivalence table until a fixpoint is reached
		boolean changeMade = true;
		while (changeMade) {
			changeMade = false;
			for (int i=0; i<nodeCount; ++i) {
				for (int j=0; j<i; ++j) {
					if (notEquivalent[i].get(j)) continue;
					for (int c=0; c<alphabet.size(); ++c) {
						int tI = this.transitionTable.get(i)[c];
						int tJ = this.transitionTable.get(j)[c];
						if (tI==tJ) continue;
						// the table is lower-triangular, so index with the larger node first
						if (tI<tJ) {
							int tmp = tI;
							tI=tJ;
							tJ=tmp;
						}
						if (notEquivalent[tI].get(tJ)) {
							notEquivalent[i].set(j, true);
							changeMade=true;
							break; // one distinguishing character suffices
						}
					}
				}
			}
		}
		// create table to know which nodes should be substituted by others
		int substitutionTable[] = new int[nodeCount];
		int[] nodeMap = new int[nodeCount];
		for (int i=0; i<nodeCount; ++i) substitutionTable[i]=-1;
		int nodesNeeded = 0;
		for (int i=0; i<nodeCount; ++i) {
			if (substitutionTable[i]!=-1) continue;
			nodeMap[i]=nodesNeeded;
			++nodesNeeded;
			// use -2 to indicate that current node is to be used as representative for
			// equivalence class (because it's the node in this class with the lowest
			// index)
			substitutionTable[i]=-2;
			// check which of the rest of the nodes fall into this equivalence class
			for (int j=i+1; j<nodeCount; ++j) {
				if (!notEquivalent[j].get(i)) substitutionTable[j]=i;
			}
		}
		// create new node table
		Integer[] newOutputTable = new Integer[nodesNeeded];
		List<int[]> newTransitionTable = new ArrayList<int[]>(nodesNeeded);
		int n = 0;
		for (int i=0; i<nodeCount; ++i) {
			if (substitutionTable[i]!=-2) continue;
			newOutputTable[n]=outputTable.get(i);
			int[] transitions = new int[alphabet.size()];
			for (int c=0; c<alphabet.size(); ++c) {
				int originalTarget = transitionTable.get(i)[c];
				int k = (substitutionTable[originalTarget]==-2)?originalTarget:substitutionTable[originalTarget];
				transitions[c]=nodeMap[k];
			}
			newTransitionTable.add(transitions);
			++n;
		}

		Log.printf(Log.Level.VERBOSE, "minimization (kozen): %d%n", nodesNeeded);
		Log.stopTimer("minimize automaton (kozen)");
		
		return new CDFA(alphabet, newTransitionTable, Arrays.asList(newOutputTable));
	}

	/** Greatest common divisor (Euclid's algorithm). */
	static private int gcd(int a, int b) {
		if (a<b) return gcd(b,a);
		while (b!=0) {
			int tmp=b;
			b=a%b;
			a=tmp;
		}
		return a;
	}

	/** Experimental method that logs the periodicity of each state and returns
	 *  the number of states with periodicity!=1. For each state, the gcd of the
	 *  lengths n (up to 1000) after which the state is reachable from itself is computed;
	 *  a period of -1 means the state was never revisited. */
	public int analysePeriodicity() {
		int result = 0;
		// loop over all states and analyze for each its periodicity
		for (int state=0; state<transitionTable.size(); ++state) {
			List<Integer> tmpList = new ArrayList<Integer>();

			int period = -1;
			HashSet<Integer> reachable = new HashSet<Integer>();
			reachable.add(state);
			int n = 0;
			while ((period!=1) && (n<1000)) {
				HashSet<Integer> nextReachable = new HashSet<Integer>();
				for (int i : reachable) {
					for (int j : transitionTable.get(i)) nextReachable.add(j);
				}
				++n;
				reachable=nextReachable;
				if (reachable.contains(state)) {
					if (period==-1) {
						period=n;
					} else {
						period=gcd(n,period);
					}
					tmpList.add(n);
				}
			}
			StringBuilder sb = new StringBuilder();
			sb.append(String.format("%d: %d (", state, period));
			for (int i : tmpList) {
				sb.append(i);
				sb.append(',');
			}
			sb.append(')');
			Log.println(Log.Level.DEBUG, sb.toString());
			if (period!=1) ++result;
		}
		return result;
	}

	/** Naive (slow) approach to asymptotic equivalence.
	 *  WARNING: This is experimental stuff.
	 *  @throws IllegalStateException if some state has an output other than 0 or 1, or
	 *     if the reachable state sets do not stabilize within the examined lengths.
	 */
	public Partition analyseAsymptoticEquivalence() {
		// determine final states
		BitArray finalStates = new BitArray(transitionTable.size());
		for (int state=0; state<outputTable.size(); ++state) {
			if (outputTable.get(state)==0) continue;
			if (outputTable.get(state)==1) {
				finalStates.set(state, true);
			} else {
				throw new IllegalStateException("Not implemented for output(state)>1");
			}
		}
		
		// for each state the lengths of the strings leading to a final state
		BitArray[] lengthsArray = new BitArray[transitionTable.size()];
		final int maxLength = 63;
		
		// loop over all states and analyze for each its periodicity
		for (int state=0; state<transitionTable.size(); ++state) {
			// set of states currently active
			BitArray stateSet = new BitArray(transitionTable.size());
			stateSet.set(state, true);
			// possible lengths of strings leading from current state to a final state
			BitArray lengths = new BitArray(maxLength+1);
			if (finalStates.get(state)) lengths.set(0, true);
			
			// true if the state set does not change any more
			boolean stable = false;
			for (int i=1; i<=maxLength; ++i) {
				// update state set
				if (!stable) {
					BitArray newStateSet = new BitArray(transitionTable.size());
					int sourceState = -1;
					for (boolean b : stateSet) {
						++sourceState;
						if (!b) continue;
						for (int targetState : transitionTable.get(sourceState)) newStateSet.set(targetState, true); 
					}
					if (newStateSet.equals(stateSet)) {
						stable=true;
					} else {
						stateSet=newStateSet;
					}
				}
				// check if we are in a final state
				BitArray tmp = new BitArray(stateSet);
				tmp.and(finalStates);
				if (!tmp.allZero()) lengths.set(i, true);
			}
			if (!stable) throw new IllegalStateException("State set did not stabelize, increase \"maxLength\"");
			lengthsArray[state]=lengths;
			
			StringBuilder sb = new StringBuilder();
			sb.append(String.format("%d: ", state));
			int m=maxLength;
			for (; m>=1; --m) {
				if (!lengths.get(m-1)) break;
			}
			int n = 0;
			for (boolean b : lengths) {
				if (b) sb.append(String.format("%d,", n));
				++n;
				if (n>m) break;
			}
			sb.append("...");
			Log.println(Log.Level.DEBUG, sb.toString());
		}
		return new Partition(Arrays.asList(lengthsArray));
	}
	
	/** Returns an automaton recognizing the same language, but in a non-overlapping way:
	 *  every state with output != 0 gets output 1 and the outgoing transitions of the
	 *  start state, so that counting restarts after each match. */
	public CDFA toNonOverlapping() {
		// TODO: remove unreachable states
		List<int[]> newTransitionTable = new ArrayList<int[]>(transitionTable.size());
		List<Integer> newOutputTable = new ArrayList<Integer>(outputTable.size());
		
		for (int state=0; state<transitionTable.size(); ++state) {
			int[] newTransitions = new int[alphabet.size()];
			if (outputTable.get(state)==0) {
				System.arraycopy(transitionTable.get(state), 0, newTransitions, 0, alphabet.size());
				newOutputTable.add(0);
			} else {
				System.arraycopy(transitionTable.get(0), 0, newTransitions, 0, alphabet.size());
				newOutputTable.add(1);
			}
			newTransitionTable.add(newTransitions);
		}
		
		return new CDFA(alphabet,newTransitionTable,newOutputTable);
	}
	
	/** Returns an automaton recognizing the same language, but returns match positions only (if
	 *  two matches end at the same position, only one is reported). Every output != 0 is
	 *  replaced by 1; transitions are copied. */
	public CDFA toMatchPositionCount() {
		List<int[]> newTransitionTable = new ArrayList<int[]>(transitionTable.size());
		List<Integer> newOutputTable = new ArrayList<Integer>(outputTable.size());
		
		for (int state=0; state<transitionTable.size(); ++state) {
			if (outputTable.get(state)==0) {
				newOutputTable.add(0);
			} else {
				newOutputTable.add(1);
			}
			int[] newTransitions = new int[alphabet.size()];
			System.arraycopy(transitionTable.get(state), 0, newTransitions, 0, alphabet.size());
			newTransitionTable.add(newTransitions);
		}
		return new CDFA(alphabet,newTransitionTable,newOutputTable);
	}
	
	/** Returns a list of states that yield an output > 0, in increasing order. */
	public List<Integer> getOutputStates() {
		List<Integer> result = new ArrayList<Integer>();
		for (int i=0; i<outputTable.size(); ++i) {
			if (outputTable.get(i)>0) {
				result.add(i);
			}
		}
		return result;
	}
	
	@Override
	public int getStateCount() {
		return outputTable.size();
	}
	
	@Override
	public int getStartState() { return 0; }
	
	/** Returns the target of a transition from a given state under a given character.
	 *  @param charIndex Index of character w.r.t. to alphabet. */
	@Override
	public int getTransitionTarget(int sourceState, int charIndex) {
		return transitionTable.get(sourceState)[charIndex];
	}
	
	@Override
	public int getAlphabetSize() {
		return alphabet.size();
	}

	/** Returns the initial partition of states, built from the state outputs. */
	@Override
	public Partition getStatePartition() {
		return new Partition(outputTable);
	}

	/** Returns the output of the given state. */
	public int getStateOutput(int state) { return outputTable.get(state); }
	
	/** Returns a textual dump: one line per state with its output and all transitions. */
	@Override
	public String toString() {
		StringBuilder sb = new StringBuilder();
		for (int node=0; node<outputTable.size(); ++node) {
			sb.append(String.format("%d(%d): ", node, outputTable.get(node)));
			for (int c=0; c<alphabet.size(); ++c) {
				char chr = alphabet.get(c);
				sb.append(String.format("%s-->%d, ", chr, transitionTable.get(node)[c]));
			}
			sb.append("\n");
		}
		return sb.toString();
	}
	
	/** Returns the maximal possible output (over all states); at least 0. */
	public int getMaxOutput() {
		int result = 0;
		for (int state=0; state<outputTable.size(); ++state) {
			result = Math.max(result, outputTable.get(state));
		}
		return result;
	}
	
}
